package net.gazhi.delonix.core.web;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.validation.AbstractBindingResult;
import org.springframework.validation.BindingResult;
import org.springframework.validation.FieldError;
import org.springframework.validation.ObjectError;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

/**
 * 在调用 MVC Controller 方法之后，对 errors 进行排序，改善错误提示的用户体验
 * 
 * @author Jeffrey Lin
 */
public class SortErrorsInterceptor extends HandlerInterceptorAdapter {

	@Override
	public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
		super.postHandle(request, response, handler, modelAndView);
		if (modelAndView == null) {
			return;
		}
		Map<String, Object> map = modelAndView.getModel();
		for (Entry<String, Object> entry : map.entrySet()) {
			String key = entry.getKey();
			Object val = entry.getValue();
			if (val instanceof AbstractBindingResult) {
				this.sortErrors((AbstractBindingResult) val, key, modelAndView);
			}
		}
	}

	/**
	 * 按照 form 中定义的顺序，对 fieldErrors 进行排序<br>
	 * 使用 getDeclaredFields，不考虑继承的情况
	 * 
	 * @param bindingResult
	 * @param resultKey
	 * @param mav
	 */
	private void sortErrors(AbstractBindingResult bindingResult, String resultKey, ModelAndView mav) {
		List<FieldError> fieldErrors = bindingResult.getFieldErrors();
		if (fieldErrors.size() == 0) {
			return;
		}
		String formKey = resultKey.substring(BindingResult.class.getName().length() + 1);
		Object form = mav.getModelMap().get(formKey);
		if (form == null) {
			return;
		}
		List<FieldError> sortedErrors = new ArrayList<FieldError>(fieldErrors.size());
		for (Field field : form.getClass().getDeclaredFields()) {
			for (FieldError fieldError : fieldErrors) {
				if (field.getName().equals(fieldError.getField())) {
					sortedErrors.add(fieldError);
				}
			}
		}
		List<ObjectError> errors = this.getErrorsByField(bindingResult);
		errors.removeAll(sortedErrors);
		errors.addAll(0, sortedErrors);
	}

	/**
	 * 通过反射机制强行获取私有的 errors
	 * 
	 * @param bindingResult
	 * @return
	 */
	@SuppressWarnings("unchecked")
	private List<ObjectError> getErrorsByField(AbstractBindingResult bindingResult) {
		try {
			Field field = AbstractBindingResult.class.getDeclaredField("errors");
			field.setAccessible(true);
			return (List<ObjectError>) field.get(bindingResult);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

}